import base64
import logging
from typing import Optional, Type

import instructor
import litellm
from instructor.core import InstructorRetryException
from litellm.exceptions import ContentPolicyViolationError
from openai import ContentFilterFinishReasonError
from pydantic import BaseModel
from tenacity import (
    retry,
    stop_after_delay,
    wait_exponential_jitter,
    retry_if_not_exception_type,
    before_sleep_log,
)

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
    LLMInterface,
)
from cognee.infrastructure.llm.exceptions import (
    ContentPolicyFilterError,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.modules.observability.get_observe import get_observe
from cognee.shared.logging_utils import get_logger

logger = get_logger()

observe = get_observe()


class OpenAIAdapter(LLMInterface):
    """
    Adapter for OpenAI's GPT-3, GPT-4 API.

    Public methods:

    - acreate_structured_output
    - create_structured_output
    - create_transcript
    - transcribe_image
    - show_prompt

    Instance variables:

    - name
    - model
    - api_key
    - api_version
    - MAX_RETRIES
    """

    name = "OpenAI"
    model: str
    api_key: str
    api_version: str

    # Upper bound on instructor's internal per-request retries; distinct from
    # the tenacity @retry decorators wrapping each public method below.
    MAX_RETRIES = 5

    def __init__(
        self,
        api_key: str,
        endpoint: str,
        api_version: str,
        model: str,
        transcription_model: str,
        max_completion_tokens: int,
        streaming: bool = False,
        fallback_model: Optional[str] = None,
        fallback_api_key: Optional[str] = None,
        fallback_endpoint: Optional[str] = None,
    ):
        # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
        #       Make sure all new gpt models will work with this mode as well.
        if "gpt-5" in model:
            self.aclient = instructor.from_litellm(
                litellm.acompletion, mode=instructor.Mode.JSON_SCHEMA
            )
            self.client = instructor.from_litellm(
                litellm.completion, mode=instructor.Mode.JSON_SCHEMA
            )
        else:
            self.aclient = instructor.from_litellm(litellm.acompletion)
            self.client = instructor.from_litellm(litellm.completion)

        self.transcription_model = transcription_model
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        self.api_version = api_version
        # NOTE(review): max_completion_tokens is stored but never used by the
        # methods below (transcribe_image hard-codes 300) — confirm intent.
        self.max_completion_tokens = max_completion_tokens
        self.streaming = streaming

        # Optional fallback deployment, used only when the primary model
        # rejects a request on content-policy grounds
        # (see acreate_structured_output).
        self.fallback_model = fallback_model
        self.fallback_api_key = fallback_api_key
        self.fallback_endpoint = fallback_endpoint

    @staticmethod
    def _build_messages(text_input: str, system_prompt: str) -> list:
        """
        Build the message list shared by the structured-output calls.

        The user message is deliberately placed before the system message to
        preserve the ordering the original implementation sent.
        """
        return [
            {
                "role": "user",
                "content": text_input,
            },
            {
                "role": "system",
                "content": system_prompt,
            },
        ]

    @observe(as_type="generation")
    @retry(
        stop=stop_after_delay(128),
        wait=wait_exponential_jitter(2, 128),
        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
        before_sleep=before_sleep_log(logger, logging.DEBUG),
        reraise=True,
    )
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """
        Generate a response from a user query.

        This method asynchronously creates structured output by sending a request to the OpenAI
        API. If the primary model rejects the request on content-policy grounds and a fallback
        model is configured, the request is retried once against the fallback before a
        ContentPolicyFilterError is raised.

        Parameters:
        -----------

            - text_input (str): The input text provided by the user for generating a response.
            - system_prompt (str): The system's prompt to guide the model's response.
            - response_model (Type[BaseModel]): The expected model type for the response.

        Returns:
        --------

            - BaseModel: A structured output generated by the model, returned as an instance of
              BaseModel.

        Raises:
        -------

            - ContentPolicyFilterError: When both the primary and the fallback calls are
              rejected by the provider's content policy.
        """

        try:
            return await self.aclient.chat.completions.create(
                model=self.model,
                messages=self._build_messages(text_input, system_prompt),
                api_key=self.api_key,
                api_base=self.endpoint,
                api_version=self.api_version,
                response_model=response_model,
                max_retries=self.MAX_RETRIES,
            )
        except (
            ContentFilterFinishReasonError,
            ContentPolicyViolationError,
            InstructorRetryException,
        ) as e:
            # Without a configured fallback there is nothing else to try.
            if not (self.fallback_model and self.fallback_api_key):
                raise e
            try:
                return await self.aclient.chat.completions.create(
                    model=self.fallback_model,
                    messages=self._build_messages(text_input, system_prompt),
                    api_key=self.fallback_api_key,
                    # api_base=self.fallback_endpoint,
                    response_model=response_model,
                    max_retries=self.MAX_RETRIES,
                )
            except (
                ContentFilterFinishReasonError,
                ContentPolicyViolationError,
                InstructorRetryException,
            ) as error:
                # InstructorRetryException can wrap failures unrelated to
                # content filtering; only map genuine policy rejections to
                # ContentPolicyFilterError and re-raise everything else.
                if (
                    isinstance(error, InstructorRetryException)
                    and "content management policy" not in str(error).lower()
                ):
                    raise error
                raise ContentPolicyFilterError(
                    f"The provided input contains content that is not aligned with our content policy: {text_input}"
                ) from error

    @observe
    @retry(
        stop=stop_after_delay(128),
        wait=wait_exponential_jitter(2, 128),
        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
        before_sleep=before_sleep_log(logger, logging.DEBUG),
        reraise=True,
    )
    def create_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
        """
        Generate a response from a user query.

        This method creates structured output by sending a synchronous request to the OpenAI API
        using the provided parameters to generate a completion based on the user input and
        system prompt. Unlike the async variant, no content-policy fallback is attempted here.

        Parameters:
        -----------

            - text_input (str): The input text provided by the user for generating a response.
            - system_prompt (str): The system's prompt to guide the model's response.
            - response_model (Type[BaseModel]): The expected model type for the response.

        Returns:
        --------

            - BaseModel: A structured output generated by the model, returned as an instance of
              BaseModel.
        """

        return self.client.chat.completions.create(
            model=self.model,
            messages=self._build_messages(text_input, system_prompt),
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            response_model=response_model,
            max_retries=self.MAX_RETRIES,
        )

    @retry(
        stop=stop_after_delay(128),
        wait=wait_exponential_jitter(2, 128),
        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
        before_sleep=before_sleep_log(logger, logging.DEBUG),
        reraise=True,
    )
    async def create_transcript(self, input):
        """
        Generate an audio transcript from a user query.

        This method creates a transcript from the specified audio file. Opening the file is
        delegated to open_data_file, which is expected to raise if the path does not exist.

        Parameters:
        -----------

            - input: The path to the audio file that needs to be transcribed.

        Returns:
        --------

            The generated transcription of the audio file.
        """

        async with open_data_file(input, mode="rb") as audio_file:
            # NOTE(review): litellm.transcription is a blocking call inside an
            # async method; consider litellm.atranscription — confirm parity first.
            transcription = litellm.transcription(
                model=self.transcription_model,
                file=audio_file,
                api_key=self.api_key,
                api_base=self.endpoint,
                api_version=self.api_version,
                max_retries=self.MAX_RETRIES,
            )

        return transcription

    @retry(
        stop=stop_after_delay(128),
        wait=wait_exponential_jitter(2, 128),
        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
        before_sleep=before_sleep_log(logger, logging.DEBUG),
        reraise=True,
    )
    async def transcribe_image(self, input) -> BaseModel:
        """
        Generate a transcription of an image from a user query.

        This method base64-encodes the image and sends a request to the OpenAI API to obtain a
        description of the contents of the image.

        Parameters:
        -----------

            - input: The path to the image file that needs to be transcribed.

        Returns:
        --------

            - BaseModel: A structured output generated by the model, returned as an instance of
              BaseModel.
        """
        async with open_data_file(input, mode="rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

        # NOTE(review): litellm.completion is a blocking call inside an async
        # method; consider awaiting litellm.acompletion — confirm parity first.
        return litellm.completion(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "What's in this image?",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                # Media type is assumed to be JPEG regardless of
                                # the actual file — TODO confirm upstream inputs.
                                "url": f"data:image/jpeg;base64,{encoded_image}",
                            },
                        },
                    ],
                }
            ],
            api_key=self.api_key,
            api_base=self.endpoint,
            api_version=self.api_version,
            # Hard-coded cap kept from the original; self.max_completion_tokens
            # is intentionally not used here to preserve behavior.
            max_completion_tokens=300,
            max_retries=self.MAX_RETRIES,
        )
